# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------

# Autogenerated By   : src/main/python/generator/generator.py
# Autogenerated From : scripts/builtin/shapExplainer.dml

from typing import Dict, Iterable

from systemds.operator import OperationNode, Matrix, Frame, List, MultiReturn, Scalar
from systemds.utils.consts import VALID_INPUT_TYPES


def shapExplainer(model_function: str,
                  model_args: List,
                  x_instances: Matrix,
                  X_bg: Matrix,
                  **kwargs: Dict[str, VALID_INPUT_TYPES]):
    """
    Computes shapley values for multiple instances in parallel using antithetic permutation sampling.
    The resulting matrix phis holds the shapley values for each feature in the column given by the index of the
    feature in the sample.

    This method first creates two large matrices for masks and masked background data for all permutations and
    then runs in parallel on all instances in x.
    While the prepared matrices can become very large (2 * #features * #permutations * #n_samples * #features),
    the preparation of a row for the model call breaks down to a single element-wise multiplication of this mask
    with the row and an addition to the masked background data, since masks can be reused for each instance.

    The call only builds the operation node; both results are exposed as the output nodes of the returned
    ``MultiReturn`` (a ``Matrix`` followed by a ``Scalar``).

    :param model_function: The function of the model to be evaluated as a String. This function has to take a matrix
        of samples and return a vector of predictions.
        It might be useful to wrap the model into a function that takes and returns the desired shapes and
        use this wrapper here.
    :param model_args: Arguments in order for the model, if desired. This will be prepended by the created
        instances-matrix.
    :param x_instances: Multiple instances as rows for which to compute the shapley values.
    :param X_bg: The background dataset from which to pull the random samples to perform Monte Carlo integration.
    :param n_permutations: The number of permutations. Defaults to 10. Theoretically 1 should already be enough for
        models with up to second order interaction effects.
    :param n_samples: Number of samples from X_bg used for marginalization.
    :param remove_non_var: EXPERIMENTAL: If set, for every instance the variance of each feature is checked against
        this feature in the background data. If it does not change, we do not run any model calls for it.
    :param seed: A seed, in case the sampling has to be deterministic.
    :param verbose: A boolean to enable logging of each step of the function.
    :return: Matrix holding the shapley values along the cols, one row per instance.
    :return: Double holding the average prediction of all instances.
    """

    params_dict = {'model_function': model_function, 'model_args': model_args, 'x_instances': x_instances, 'X_bg': X_bg}
    params_dict.update(kwargs)

    # model_function is documented as a plain str, which has no `sds_context`
    # attribute. Keep using it when it does expose one (so previously working
    # calls behave the same), otherwise take the context from x_instances.
    context_source = model_function if hasattr(model_function, 'sds_context') else x_instances
    sds_context = context_source.sds_context

    vX_0 = Matrix(sds_context, '')
    vX_1 = Scalar(sds_context, '')
    output_nodes = [vX_0, vX_1]

    op = MultiReturn(sds_context, 'shapExplainer', output_nodes, named_input_nodes=params_dict)

    # Wire each output placeholder back to the multi-return operation that produces it.
    vX_0._unnamed_input_nodes = [op]
    vX_1._unnamed_input_nodes = [op]

    return op
